# 引入模块
import time
import torch
import torch.nn as nn
import torch_npu
# 导入AMP模块
from torch_npu.npu import amp
from torch.utils.data import Dataset, DataLoader
import torchvision
from enum import Enum

# Select the device to run on; change this to match your own training device.
# This module-level name is read directly by train() when moving batches.
device = torch.device('npu:0')
# The following NPU configuration calls are left commented out:
#torch.npu.set_compile_mode(jit_compile=False)
#option = {"MM_BMM_ND_ENABLE":"disable"}
#torch_npu.npu.set_option(option)
 
 
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary()`` reports."""

    NONE = 0     # summary() yields an empty string
    AVERAGE = 1  # summary() reports the running mean
    SUM = 2      # summary() reports the accumulated sum
    COUNT = 3    # summary() reports the accumulated count


class ProgressMeter:
    """Print a group of meters behind a ``[current/total]`` batch counter."""

    def __init__(self, num_batches, meters, prefix=""):
        """Store the meters and pre-build the batch counter template.

        Args:
            num_batches: Total number of batches, shown after the slash.
            meters: Iterable of objects supporting ``str()`` and ``summary()``.
            prefix: Text printed immediately before the batch counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, batch counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print every meter's ``summary()`` on one space-separated line."""
        parts = [" *"]
        parts.extend(meter.summary() for meter in self.meters)
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``[{:Nd}/total]`` template, N being the width of the total."""
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return f'[{slot}/{slot.format(num_batches)}]'


class AverageMeter(object):
    """Computes and stores the average and current value of a metric.

    The instance fields ``name``, ``val``, ``avg``, ``sum`` and ``count`` are
    read straight out of ``self.__dict__`` by the format strings in
    ``__str__`` and ``summary``.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        """Create a zeroed meter.

        Args:
            name: Label printed in front of the values.
            fmt: Format-spec suffix (including the leading ``:``) applied to
                ``val`` and ``avg`` in ``__str__``.
            summary_type: ``Summary`` member selecting what ``summary()``
                reports.
        """
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero the latest value, mean, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across processes and recompute ``avg``.

        The caller is expected to have initialised the default
        ``torch.distributed`` process group beforehand. After the call
        ``sum`` and ``count`` are Python floats (they come back from a
        float32 tensor via ``tolist()``).
        """
        # Local import: the file never imports torch.distributed, so the
        # bare ``dist`` name used here previously raised NameError.
        import torch.distributed as dist

        if torch.npu.is_available():
            reduce_device = torch.device("npu")
        elif torch.backends.mps.is_available():
            reduce_device = torch.device("mps")
        else:
            reduce_device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32,
                             device=reduce_device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` with both numbers formatted by ``fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the one-statistic summary chosen by ``summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a known ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages, one single-element tensor per k.

    Args:
        output: Score tensor; the top-k search runs along dimension 1.
        target: Label tensor whose first dimension is the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        List of tensors, in the order of ``topk``, each holding the
        percentage of samples whose label is among the k highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the highest scores, transposed so row i holds every
        # sample's (i+1)-th ranked prediction.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        to_percent = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent)
            for k in topk
        ]


def train(train_dataloader, model, criterion, optimizer, epoch):
    """Run one mixed-precision training epoch and print progress.

    Relies on two module-level names that must exist before the first call:
    ``device`` (each batch is moved to it) and ``scaler`` (the AMP
    ``GradScaler`` used for loss scaling and the optimizer step).

    Args:
        train_dataloader: Yields ``(imgs, labels)`` batches; must support
            ``len()`` and expose ``.dataset``.
        model: Network called on each image batch.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer to step through ``scaler``.
        epoch: Epoch index, used only in the printed messages.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    num_batches = len(train_dataloader)
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Running totals for the end-of-epoch line. The previous code derived
    # the sample count from ``batch_idx * len(imgs)`` of the final batch,
    # which under-reports whenever the last batch is smaller than the rest.
    samples_seen = 0
    batches_done = 0
    last_loss = None

    end = time.time()
    for batch_idx, (imgs, labels) in enumerate(train_dataloader):
        data_time.update(time.time() - end)
        imgs = imgs.to(device)      # move the image batch to the selected device
        labels = labels.to(device)  # move the labels to the selected device
        batch_len = imgs.size(0)
        with amp.autocast():        # run forward pass and loss under AMP
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        optimizer.zero_grad()
        # Loss scaling around back-propagation and the parameter update.
        scaler.scale(loss).backward()  # scale the loss, then back-propagate
        scaler.step(optimizer)         # update parameters (unscales automatically)
        scaler.update()                # refresh the dynamic loss-scale factor

        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        last_loss = loss.item()
        losses.update(last_loss, batch_len)
        top1.update(acc1[0], batch_len)
        top5.update(acc5[0], batch_len)
        samples_seen += batch_len
        batches_done = batch_idx + 1
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 50 == 0:
            progress.display(batch_idx + 1)

    if batches_done == 0:
        # Empty loader: nothing ran, so there is no loss or batch to report
        # (the old summary line raised NameError here).
        return
    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, samples_seen, len(train_dataloader.dataset),
        100. * batches_done / num_batches, last_loss))


class CNN(nn.Module):
    """Small convolutional classifier with ten output scores.

    Two 3x3 convolutions (padding 1), each followed by 2x2 max pooling,
    feed a flatten step and a two-layer fully connected head. The first
    linear layer expects ``32*7*7`` features, which matches single-channel
    28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
        ]
        # Kept as one Sequential named ``net`` so parameter names
        # (``net.0.weight`` ...) stay the same.
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the ten class scores for input batch ``x``."""
        return self.net(x)

# Fetch the dataset (downloaded into ./dataset when not already present).
train_data = torchvision.datasets.MNIST(
    "./dataset",
    download = True,
    train = True,
    transform = torchvision.transforms.ToTensor()
)
# Define the training-related parameters.
batch_size = 64
model = CNN().to(device)  # move the model to the selected NPU
# NOTE(review): no shuffle argument is passed to the DataLoader — confirm this is intended.
train_dataloader = DataLoader(train_data, batch_size=batch_size)    # build the DataLoader
loss_func = nn.CrossEntropyLoss().to(device)    # define the loss function
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1)    # define the optimizer
# train() reads ``scaler`` as a module-level name, so it must be defined before the loop below.
scaler = amp.GradScaler()    # create the GradScaler after the model and optimizer are defined
epochs = 10  # number of training epochs

for epoch in range(epochs):
    train(train_dataloader, model, loss_func, optimizer, epoch)
# NOTE(review): this saves the whole model object rather than its state_dict.
torch.save(model,"./mnist.pt")
